#include <arch/interrupt.h>
#include <arch/memio.h>
#include <os/virmem.h>
#include <os/debug.h>
#include <os/mutexlock.h>
#include <os/memcache.h>
#include <os/mutexlock.h>
#include <lib/bitmap.h>
#include <lib/list.h>
#include<lib/assert.h>

// One bit per allocatable page in the dynamic mapping window (set = in use).
static bitmap_t virbase_bitsmap;
// Page-aligned base virtual address the bitmap indexes from (bit 0 == virbase).
static uint32_t virbase;
// Areas currently handed out to callers (vir_mem_t nodes).
static LIST_HEAD(using_vir_mem_list_head);
// Recycled areas kept mapped for reuse; VirMemDoFree inserts in ascending size order.
static LIST_HEAD(free_vir_mem_list_head);
// Guards virbase_bitsmap and both lists above.
static DEFINE_MUTEX_LOCK(vir_mem_lock);

void VirMemInit()
{
    // init vir mem bitsmap
    virbase_bitsmap.length = DYNAMIC_MAP_MEM_SIZE / PAGE_SIZE * 8;
    assert(virbase_bitsmap.length);
    virbase_bitsmap.bits = KMemAlloc(virbase_bitsmap.length);
    if(!virbase_bitsmap.bits)
        return;
    init_bitmap(&virbase_bitsmap);
    // set virbase
    virbase = PageAlign(VIR_MEM_BASE);
    // init vir mem object list
    list_init(&using_vir_mem_list_head);
    list_init(&free_vir_mem_list_head);
    #ifdef VIRMEM_DEBUG
    KPrint("debug: bitmap len %d bits %x virbase %x\n",(uint32_t)virbase_bitsmap.length,(uint32_t)virbase_bitsmap.bits,(uint32_t)VIR_MEM_BASE);
    #endif
    KPrint(PRINT_INFO "[virmem] virmem init finished!\n");
}

/**
 * Reserve a run of virtual pages covering `size` bytes (page-aligned up).
 *
 * @param size  requested size in bytes; 0 (after alignment) is rejected.
 * @return base virtual address on success, 0 if size is 0,
 *         (address_t)-1 if no free run exists.
 *         NOTE(review): the dual failure convention (0 and -1) is
 *         historical; callers must check both.
 */
address_t VirBaseAlloc(size_t size)
{
    uint32_t pages = 0;
    int idx = -1;

    #ifdef VIRMEM_DEBUG
    KPrint("[virmem] alloc size %d[%d] pages %d\n",size,PageAlign(size),PageAlign(size)/PAGE_SIZE);
    #endif 

    // Round up to whole pages.
    size = PageAlign(size);
    if (!size)
        return 0;
    pages = size / PAGE_SIZE;

    // Hold the lock across find + set: with the search outside the lock
    // (as before), two concurrent callers could both find the same free
    // run and hand out overlapping address ranges.
    MutexlockLock(&vir_mem_lock, MUTEX_LOCK_MODE_BLOCK);
    idx = bitmap_find_nfree(&virbase_bitsmap, pages);
    if (idx != -1)
    {
        // Mark every page of the run as used.
        for (uint32_t i = 0; i < pages; i++)
        {
            bitmap_set(&virbase_bitsmap, 1, idx + i);
            assert(bitmap_test(&virbase_bitsmap, idx + i));
        }
        MutexlockUnlock(&vir_mem_lock);
        #ifdef VIRMEM_DEBUG
        KPrint("[virmem] alloc pages %d idx %d return addr %x\n",pages,idx,virbase+idx*PAGE_SIZE);
        #endif 
        return virbase + idx * PAGE_SIZE;
    }
    MutexlockUnlock(&vir_mem_lock);
    #ifdef VIRMEM_DEBUG
    KPrint("[virmem] alloc pages %d failed\n",pages);
    #endif
    return -1;
}

/**
 * Release a run of virtual pages previously reserved by VirBaseAlloc.
 *
 * @param vbase  base virtual address of the run (must be >= virbase).
 * @param size   size in bytes; page-aligned up before clearing.
 * @return 0 on success, -1 on invalid arguments.
 */
int VirBaseFree(uint64_t vbase, uint64_t size)
{
    uint64_t pages;
    uint64_t idx;

    if (!size)
        return -1;
    // Reject addresses below the managed window: the old `idx != -1`
    // check could never fail, so negative offsets went undetected.
    if (vbase < virbase)
        return -1;
    size = PageAlign(size);
    pages = size / PAGE_SIZE;
    // First bitmap index covered by this range.
    idx = (vbase - virbase) / PAGE_SIZE;

    MutexlockLock(&vir_mem_lock, MUTEX_LOCK_MODE_BLOCK);
    // Clear bit `idx + i`, not `i`: the old loop freed pages at the
    // *start of the window* instead of the pages belonging to `vbase`,
    // corrupting unrelated allocations and leaking this range forever.
    for (uint64_t i = 0; i < pages; i++)
    {
        bitmap_set(&virbase_bitsmap, 0, idx + i);
    }
    MutexlockUnlock(&vir_mem_lock);
    return 0;
}

/**
 * Allocate a fresh virtual-memory area: reserve an address run, map it
 * to physical pages, track it on the in-use list.
 *
 * @param size  size in bytes (caller is expected to page-align).
 * @return usable pointer, or NULL on any failure (reservation, tracking
 *         allocation, or mapping).
 */
void *VirMemDoAlloc(size_t size)
{
    vir_mem_t *area;
    address_t start = VirBaseAlloc(size);
    // VirBaseAlloc signals failure as 0 (zero size) OR (address_t)-1
    // (no free run). The old `!start` test let -1 through and mapped
    // a bogus address.
    if (!start || start == (address_t)-1)
        return NULL;

    area = KMemAlloc(sizeof(vir_mem_t));
    if (area != NULL)
    {
        area->addr = start;
        area->size = size;
        // Back the range with physical pages; undo the reservation on failure.
        if (VMemMap(start, size, PROTE_KERNEL | PROTE_WRITE) < 0)
        {
            VirBaseFree(start, size);
            KMemFree(area);
            return NULL;
        }
        MutexlockLock(&vir_mem_lock, MUTEX_LOCK_MODE_BLOCK);
        list_add_tail(&area->list, &using_vir_mem_list_head);
        MutexlockUnlock(&vir_mem_lock);
        return PTYPE(area->addr);
    }
    // Tracking node allocation failed: give back the address run.
    VirBaseFree(start, size);
    return NULL;
}

/**
 * Allocate kernel virtual memory, preferring a recycled area from the
 * free list (kept mapped by VirMemDoFree) before reserving a new one.
 *
 * @param size  size in bytes; page-aligned up, 0 rejected.
 * @return usable pointer, or NULL on failure.
 */
void *VirMemAlloc(size_t size)
{
    vir_mem_t *vir = NULL, *area = NULL;

    size = PageAlign(size);
    if (!size)
        return NULL;

    // The free list is inserted in ascending size order (VirMemDoFree),
    // so the first area *large enough* is the best fit. The old test was
    // inverted (`size >= area->size`) and could return an undersized area.
    // Traversal and removal happen under the lock: the unlocked scan
    // raced with concurrent insertions/removals.
    MutexlockLock(&vir_mem_lock, MUTEX_LOCK_MODE_BLOCK);
    list_traversal_all_owner_to_next(area, &free_vir_mem_list_head, list)
    {
        if (area->size >= size)
        {
            vir = area;
            break;
        }
    }
    if (vir != NULL)
    {
        // Recycle: move the area from the free list to the in-use list.
        list_del(&vir->list);
        list_add_tail(&vir->list, &using_vir_mem_list_head);
        MutexlockUnlock(&vir_mem_lock);
        return PTYPE(vir->addr);
    }
    MutexlockUnlock(&vir_mem_lock);
    // Nothing reusable: carve out and map a fresh area.
    return VirMemDoAlloc(size);
}

/**
 * Move an area from the in-use list to the free list, keeping the free
 * list sorted by ascending size. The area stays mapped so VirMemAlloc
 * can recycle it without remapping.
 *
 * @param vir  area to recycle (currently linked on the in-use list).
 * @return 1 if inserted, -1 if no insertion point was found.
 */
int VirMemDoFree(vir_mem_t *vir)
{
    vir_mem_t *area = NULL;
    int flags = -1;

    // List surgery must not race with allocators walking these lists;
    // every other list mutation in this file holds vir_mem_lock, but this
    // function previously did not.
    MutexlockLock(&vir_mem_lock, MUTEX_LOCK_MODE_BLOCK);
    list_del(&vir->list);

    // Insert into the free list, maintaining ascending size order.
    if (list_empty(&free_vir_mem_list_head))
    {
        // Empty list: trivially insert at the head.
        list_add_head(&vir->list, &free_vir_mem_list_head);
        flags = 1;
    }
    else
    {
        // Walk until we find the first area at least as large as `vir`.
        area = list_first_owner(&free_vir_mem_list_head, vir_mem_t, list);
        do
        {
            if (area->size < vir->size)
            {
                // `vir` is bigger than everything so far; if we're at the
                // tail, it becomes the new largest entry.
                if (list_is_last(&area->list, &free_vir_mem_list_head))
                {
                    list_add_tail(&vir->list, &free_vir_mem_list_head);
                    flags = 1;
                    break;
                }
                area = list_next_owner(area, list);
            }
            else
            {
                // First area >= vir->size: insert just before it.
                list_add_before(&vir->list, &area->list);
                flags = 1;
                break;
            }
        } while (!list_is_head(&area->list)); // stop if we wrap to the head
    }
    MutexlockUnlock(&vir_mem_lock);
    return flags;
}

// free virmem by ptr from using vir mem list
int VirMemFree(void *ptr)
{
    uint64_t addr;
    vir_mem_t *vir, *area;

    if (!ptr)
        return -1;
    addr = ADDR(ptr);
    if (addr < virbase || addr >= VIR_MEM_END)
        return -1;
    list_traversal_all_owner_to_next(area, &using_vir_mem_list_head, list)
    {
        if (area->addr >= addr && addr <= area->addr + area->size)
        {
            vir = area;
            break;
        }
    }
    if (vir != NULL)
    {
        #ifdef VIRMEM_DEBUG
        KPrint("[virmem] free: ptr %x area vbase %x size %x\n",ptr,(uint32_t)vir->addr,(uint32_t)vir->size);
        #endif
        if (!VirMemDoFree(vir))
            return 0;
    }
    return -1;
}

/**
 * Map a physical (device) address range into the kernel's dynamic
 * virtual window for memory-mapped I/O.
 *
 * @param pybase  physical base address; 0 rejected.
 * @param size    size in bytes; 0 rejected.
 * @return usable virtual pointer, or NULL on failure.
 */
void *MemIoReMap(uint64_t pybase, size_t size)
{
    uint64_t vbase;
    vir_mem_t *area;

    if (!pybase || !size)
        return NULL;

    vbase = VirBaseAlloc(size);
    // VirBaseAlloc fails with (address_t)-1; widen that sentinel through
    // address_t before comparing. The old `vbase != -1` compared against
    // 64-bit all-ones, so a 32-bit address_t failure value (0xFFFFFFFF)
    // was mistaken for a valid address.
    if (vbase != (uint64_t)(address_t)-1 && vbase != 0)
    {
        area = KMemAlloc(sizeof(vir_mem_t));
        if (area != NULL)
        {
            MutexlockLock(&vir_mem_lock, MUTEX_LOCK_MODE_BLOCK);
            area->addr = vbase;
            area->size = size;
            list_add_tail(&area->list, &using_vir_mem_list_head);
            // Point the virtual range at the device's physical pages.
            HalMemIoReMap(vbase, pybase, size);
            MutexlockUnlock(&vir_mem_lock);
            return PTYPE(vbase);
        }
        // Tracking node allocation failed: release the address run.
        VirBaseFree(vbase, size);
    }
    return NULL;
}

int MemIoUnMap(void *vbase)
{
    uint64_t addr = (uint64_t)(uint32_t)vbase;
    vir_mem_t *vir, *area;

    if (!vbase)
        return -1;
    if (addr < VIR_MEM_BASE || addr > VIR_MEM_END)
        return -1;
    // search targe area object from using vir mem list
    list_traversal_all_owner_to_next(area, &using_vir_mem_list_head, list)
    {
        if (area->addr >= addr && addr <= area->addr + area->size)
        {
            vir = area;
            break;
        }
    }
    if (vir != NULL)
    {
        HalMemIoUnMap(vir->addr, vir->size);
        return 0;
    }
    return -1;
}

/**
 * Debug dump of both area lists (free and in-use) to the kernel log.
 */
void VirMemAreaDump()
{
    vir_mem_t *area;
    // Header previously lacked its newline and ran into the next line.
    KPrint("[virmem dump view]\n");
    KPrint("--------free vir mem list------------\n");
    list_traversal_all_owner_to_next(area,&free_vir_mem_list_head,list)
    {
        // Print addresses in hex consistently (was %d here, %x below) and
        // cast to uint32_t as the other KPrint call sites in this file do.
        KPrint("area addr %x size %d\n",(uint32_t)area->addr,(uint32_t)area->size);
    }
    KPrint("-------used vir mem list-------------\n");
    list_traversal_all_owner_to_next(area,&using_vir_mem_list_head,list)
    {
        KPrint("area addr %x size %d\n",(uint32_t)area->addr,(uint32_t)area->size);
    }
}